% y = f_W(x) = W x, where W = eye
% y = [y1 y2 y3], delta = [d1 d2 d3]
% f_W(x+delta) = x+delta = y+delta = [y1+d1 y2+d2 y3+d3]
% consider y = [y1 y2 y3] = [3 2 1]. also assume the correct label is y1
% xent loss: log(exp(y1+d1)+exp(y2+d2)+exp(y3+d3)) - (y1+d1)
% xent loss: log(1+exp(y2+d2-y1-d1)+exp(y3+d3-y1-d1))
% delta_* = [-.501 .5 0] is enough to flip the label
% y + delta_* = [2.499 2.5 1], so the prediction becomes y2
% || delta_* || = .7078
% we grid search over the cube [-.51, +.51]^3 (slightly wider than delta_*)
% and exclude delta if || delta || > || [-.51 .5 0] || (about .7142, slightly
% larger than || delta_* ||)
% we then look at the max of L and max of P
% where L is the xent loss and P is 0-1 loss
% and print delta that achieved these max values

% Build the loss landscapes over the perturbation grid.
% For every grid perturbation delta = [d(i) d(j) d(k)] with ||delta|| <= b:
%   L(i,j,k) - cross-entropy loss w.r.t. the correct label y1:
%              log(1 + exp(z2-z1) + exp(z3-z1)), z = y + delta
%   X(i,j,k) - negated cross-entropy loss w.r.t. label y2, i.e. the
%              log-softmax of the second logit
%   P(i,j,k) - true when z1 is strictly below z2 or z3 (0-1 loss)
% Grid points outside the ball get sentinels that sort below every in-ball
% value in a descending sort, and P stays false there.
clear all
y1=3; y2=2; y3=1;
% CHANGE THE FOLLOWING LINE TO d = linspace(-.51,.51,800); TO GET EXACT
% SAME RESULTS AS IN THE PAPER
d = linspace(-.51,.51,400);
l=length(d);
% Radius of the admissible perturbation ball, norm([-.51 .5 0]) ~ 0.7142
b = norm([-.51 .5 0]);
L=zeros(l,l,l); % L is xent loss w.r.t. label y1
X=zeros(l,l,l); % X is negated xent loss w.r.t. label y2 (log-prob of y2)
P=false(l,l,l); % P is 0-1 loss
for i = 1:l
    i  % progress indicator
    for j = 1:l
        for k = 1:l
            if norm([d(i) d(j) d(k)]) <= b
                % xent loss for label y1: the two exp terms are siblings
                % inside the log, not nested in each other
                L(i,j,k) = log(1 + exp(y2+d(j)-y1-d(i)) + exp(y3+d(k)-y1-d(i)));
                X(i,j,k) = -log( exp(y1+d(i)-y2-d(j)) + 1 + exp(y3+d(k)-y2-d(j)) );
                % 0-1 error: label y1 is no longer the strict argmax
                if y1+d(i) < y2+d(j) || y1+d(i) < y3+d(k)
                    P(i,j,k) = true;
                end
            else
                % Sentinels must rank below every valid entry when sorted
                % in descending order. L is always positive, so -1 works.
                % X is a log-probability whose valid values lie on both
                % sides of -1, so it needs -Inf.
                L(i,j,k) = -1;
                X(i,j,k) = -Inf;
            end
        end
    end
end

% Persist the raw loss grids, then rank every grid point by each loss.
save('mem800.mat', 'L', 'P', 'X', '-v7.3');

sol_size = nnz(P);   % number of grid points where the label flips
max_size = numel(L); % rank the whole grid

% nL / nX: linear indices of grid points in descending order of L / X
[mL, nL] = sort(L(:), 'descend');
disp('done')
[mX, nX] = sort(X(:), 'descend');
disp('done')

% Convert the ranked linear indices back to (i,j,k) subscripts
[xL, yL, zL] = ind2sub(size(L), nL);
disp('done')
[xX, yX, zX] = ind2sub(size(X), nX);
disp('done')
save('stat800.mat', 'xL', 'yL', 'zL', 'xX', 'yX', 'zX', '-v7.3');
% Ranking curves: Lrat(i) (resp. Xrat(i)) is the fraction of all
% label-flipping grid points (P true, sol_size of them) that appear among
% the top-i grid points ranked by L (resp. X) in descending order.
% nL/nX are the linear indices returned by the descending sorts, and
% xL/yL/zL and xX/yX/zX are those same indices after ind2sub, so
% P(nL(i)) is exactly P(xL(i),yL(i),zL(i)).
% This is vectorized with cumsum instead of a scalar loop over max_size
% (= l^3) entries. The counts are exact integers, so the ratios are
% bit-identical to the loop version.
hitsL = P(nL(1:max_size));
hitsX = P(nX(1:max_size));
countL = nnz(hitsL);   % label-flipping points among the ranked entries
countX = nnz(hitsX);
Lrat = cumsum(double(hitsL(:))).' / sol_size;
Xrat = cumsum(double(hitsX(:))).' / sol_size;
save('plot800.mat', 'Lrat', 'Xrat', '-v7.3');

% [sol_size, countL, countL/sol_size]
% 
% disp('maximizing negated loss')
% [mX, nX] = maxk(X(:),sol_size);
% [xX, yX, zX] = ind2sub(size(X),nX);
% countX = 0;
% for i = 1:sol_size
%     if P(xX(i),yX(i),zX(i)) == true
%         countX = countX + 1;
%     end
% end
% [sol_size, countX, countX/sol_size]


% 
% disp('maximizing xent loss')
% [m, n] = max(L(:));
% [x, y, z] = ind2sub(size(L),n);
% disp(['delta = ', num2str(d(x)), ', ', num2str(d(y)), ', ', num2str(d(z))])
% % OUTPUT: delta = -0.501, 0.48194, 0.1329
% disp(['xent loss = ', num2str(L(x,y,z))])
% % OUTPUT: xent loss = 0.81923
% disp(['0-1 loss = ', num2str(P(x,y,z))])
% % OUTPUT: 0-1 loss = 0
% 
% disp('maximizing negated loss')
% [m, n] = max(X(:));
% [x, y, z] = ind2sub(size(X),n);
% disp(['delta = ', num2str(d(x)), ', ', num2str(d(y)), ', ', num2str(d(z))])
% % OUTPUT: delta = -0.501, 0.48194, 0.1329
% disp(['xent loss = ', num2str(L(x,y,z))])
% % OUTPUT: xent loss = 0.81923
% disp(['0-1 loss = ', num2str(P(x,y,z))])
% % OUTPUT: 0-1 loss = 0
% 
% [m, n] = max(P(:));
% [x, y, z] = ind2sub(size(P),n);
% disp('maximizing 0-1 loss')
% disp(['delta = ', num2str(d(x)), ', ', num2str(d(y)), ', ', num2str(d(z))])
% % OUTPUT: delta = -0.501, 0.5, -0.0015045
% disp(['xent loss = ', num2str(L(x,y,z))])
% % OUTPUT: xent loss = 0.81141
% disp(['0-1 loss = ', num2str(P(x,y,z))])
% % OUTPUT: 0-1 loss = 1